import os
import torch.utils.data as data
import torchvision.transforms as transforms
from collections import defaultdict
import numpy as np
import torch
from PIL import Image, ImageCms
from skimage.segmentation import slic
from skimage.measure import regionprops_table
from skimage.feature import local_binary_pattern
from sklearn.metrics.pairwise import euclidean_distances
from skimage import color
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from fast_slic.avx2 import SlicAvx2
from dataset.constants import *
import matplotlib.pyplot as plt
from scipy import sparse as sp
from scipy.spatial.distance import pdist, squareform
from dataset.attributes import *

class Resize(object):
    """Resize the image and mask of a sample to the same ``size`` x ``size`` square.

    Both images are resampled bilinearly; only the ``'image'`` and ``'mask'``
    entries are carried over to the returned sample.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, sample):
        target = (self.size, self.size)
        return {
            key: sample[key].resize(target, resample=Image.BILINEAR)
            for key in ('image', 'mask')
        }


class RandomCrop(object):
    """Resize a sample to a square, then take one random square crop of image and mask.

    The same crop box is applied to both images so they stay aligned.

    Args:
        crop_size: side length of the returned crop.
        resize_size: side length both images are resized to before cropping;
            must be at least ``crop_size``.
    """

    def __init__(self, crop_size, resize_size):
        self.crop_size = crop_size
        self.resize_size = resize_size

    def __call__(self, sample):
        """Return ``{'image', 'mask'}`` cropped to ``crop_size`` x ``crop_size``.

        Raises:
            ValueError: if ``crop_size`` exceeds ``resize_size``.
        """
        img, mask = sample['image'], sample['mask']
        target = (self.resize_size, self.resize_size)
        img = img.resize(target, resample=Image.BILINEAR)
        mask = mask.resize(target, resample=Image.BILINEAR)

        # PIL reports size as (width, height).
        w, h = img.size
        new_h = new_w = self.crop_size
        if new_h > h or new_w > w:
            raise ValueError(
                'crop_size {} is larger than resize_size {}'.format(self.crop_size, self.resize_size))

        # randint's upper bound is exclusive: +1 so every valid offset can be
        # drawn, including offset 0 when resize_size == crop_size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)
        box = (left, top, left + new_w, top + new_h)

        return {'image': img.crop(box), 'mask': mask.crop(box)}


class RandomFlip(object):
    """With probability ``prob``, mirror image and mask horizontally together.

    The coin is flipped once per sample so both images are either flipped or
    left untouched; an untouched sample is returned as the same object.
    """

    def __init__(self, prob):
        self.prob = prob
        # Probability 1: the decision is made in __call__, this always flips.
        self.flip = transforms.RandomHorizontalFlip(1.)

    def __call__(self, sample):
        if not np.random.random_sample() < self.prob:
            return sample
        flipper = self.flip
        return {'image': flipper(sample['image']), 'mask': flipper(sample['mask'])}


class ToTensorSP(object):
    """Turn an image/mask sample into SLIC superpixel node features and adjacency.

    The superpixel with label ``k`` is stored in row ``k - 1`` of fixed-size
    per-node arrays with ``num_seg`` rows; rows of labels SLIC did not emit
    stay zero.

    Args:
        num_seg: ``n_segments`` passed to SLIC and the row count of every
            per-node output.
        compactness: SLIC compactness.
    """

    def __init__(self, num_seg, compactness):
        self.tensor = transforms.ToTensor()
        self.num_seg = num_seg
        self.compactness = compactness

    def __call__(self, sample):
        """Segment ``sample['image']`` and build per-superpixel tensors.

        Returns a dict with:
            features: ``(num_seg, 15)`` float tensor: centroid row/col, area
                as a fraction of the image, mean RGB / 255, ``image_stdev``
                per channel / 255, and the six ``eccen`` outputs.
            seq_mask: ``(num_seg,)`` float tensor, 1 where the mean of
                ``mask / 255`` over the superpixel is >= 0.5.
            segments: the SLIC label map.
            mask, img: ``transforms.ToTensor()`` of the inputs.
            neighbor_array: ``(num_seg, num_seg)`` symmetric 0/1 adjacency over
                8-connected pixel pairs; self-pairs are kept, so the diagonal
                is set for labels with two touching pixels of their own.
        """
        img, mask = sample['image'], sample['mask']
        img_np = np.array(img)
        n_pixels = img_np.shape[0] * img_np.shape[1]
        mask_np = np.array(mask) / 255.
        segments = slic(img_np, n_segments=self.num_seg,
                        compactness=self.compactness,
                        max_num_iter=3,
                        convert2lab=True,
                        enforce_connectivity=True,
                        slic_zero=False)
        # NOTE(review): n_segments is only a target for SLIC; a label above
        # num_seg would make the ``label - 1`` indexing below raise IndexError.

        # Label pairs of horizontally, vertically and diagonally touching pixels.
        bneighbors = np.unique(np.hstack([
            np.vstack([segments[:, :-1].ravel(), segments[:, 1:].ravel()]),
            np.vstack([segments[:-1, :].ravel(), segments[1:, :].ravel()]),
            np.vstack([segments[:-1, :-1].ravel(), segments[1:, 1:].ravel()]),
            np.vstack([segments[1:, :-1].ravel(), segments[:-1, 1:].ravel()]),
        ]), axis=1)

        regions = regionprops_table(
            segments, intensity_image=img_np,
            properties=('label', 'centroid', 'area', 'intensity_mean', 'coords'),
            extra_properties=[image_stdev, eccen])

        label = regions['label']
        rows = label - 1
        features = np.zeros([self.num_seg, 15])
        features[rows, 0] = regions['centroid-0']
        features[rows, 1] = regions['centroid-1']
        # Fraction of the image covered; H * W equals size**2 for the square
        # images produced by Resize / RandomCrop.
        features[rows, 2] = regions['area'] / n_pixels
        for c in range(3):
            features[rows, 3 + c] = regions[f'intensity_mean-{c}'] / 255.
            features[rows, 6 + c] = regions[f'image_stdev-{c}'] / 255.
        for k in range(6):
            features[rows, 9 + k] = regions[f'eccen-{k}']

        seq_mask = np.zeros([self.num_seg])
        for lab, coord in zip(label, regions['coords']):
            inside = mask_np[coord[:, 0], coord[:, 1]]
            seq_mask[lab - 1] = 1 if np.sum(inside) / len(inside) >= 0.5 else 0

        neighbor_array = np.zeros([self.num_seg, self.num_seg])
        neighbor_array[bneighbors[0] - 1, bneighbors[1] - 1] = 1
        neighbor_array[bneighbors[1] - 1, bneighbors[0] - 1] = 1

        return {
            'features': torch.tensor(features).float(),
            'seq_mask': torch.tensor(seq_mask).float(),
            'segments': torch.tensor(segments),
            'mask': self.tensor(mask),
            'img': self.tensor(img),
            'neighbor_array': torch.tensor(neighbor_array).float(),
        }

class ToTensorSPLAP(object):
    """Superpixel node features plus Laplacian-eigenvector positional encodings.

    SLIC runs without connectivity enforcement. Superpixel ``k`` occupies row
    ``k - 1`` of every per-node output; rows of unused labels stay zero.

    Args:
        num_seg: ``n_segments`` passed to SLIC and the node count of every
            per-node output.
        compactness: SLIC compactness.
    """

    def __init__(self, num_seg, compactness):
        self.tensor = transforms.ToTensor()
        self.num_seg = num_seg
        self.compactness = compactness

    def __call__(self, sample):
        """Segment ``sample['image']`` and build per-superpixel tensors.

        Returns a dict with:
            features: ``(num_seg, 9)``: centroid row/col, area fraction,
                mean RGB / 255, ``image_stdev`` per channel / 255.
            seq_mask: ``(num_seg,)``, 1 where the mean of ``mask / 255`` over
                the superpixel is >= 0.5.
            segments: the SLIC label map.
            mask, img: ``transforms.ToTensor()`` of the inputs.
            neighbor_array: ``(num_seg, num_seg)`` symmetric 0/1 adjacency over
                8-connected pixel pairs, with a zero diagonal.
            pos_enc: ``(num_seg, num_seg - 1)`` eigenvectors of the normalised
                Laplacian of ``neighbor_array``, ordered by ascending
                eigenvalue, without the first one.
        """
        img, mask = sample['image'], sample['mask']
        img_np = np.array(img)
        n_pixels = img_np.shape[0] * img_np.shape[1]
        mask_np = np.array(mask) / 255.
        segments = slic(img_np, n_segments=self.num_seg,
                        compactness=self.compactness,
                        max_num_iter=3,
                        convert2lab=True,
                        enforce_connectivity=False,
                        slic_zero=False)

        # Label pairs of horizontally, vertically and diagonally touching pixels.
        bneighbors = np.unique(np.hstack([
            np.vstack([segments[:, :-1].ravel(), segments[:, 1:].ravel()]),
            np.vstack([segments[:-1, :].ravel(), segments[1:, :].ravel()]),
            np.vstack([segments[:-1, :-1].ravel(), segments[1:, 1:].ravel()]),
            np.vstack([segments[1:, :-1].ravel(), segments[:-1, 1:].ravel()]),
        ]), axis=1)

        regions = regionprops_table(
            segments, intensity_image=img_np,
            properties=('label', 'centroid', 'area', 'intensity_mean', 'coords'),
            extra_properties=[image_stdev])

        label = regions['label']
        rows = label - 1
        features = np.zeros([self.num_seg, 9])
        features[rows, 0] = regions['centroid-0']
        features[rows, 1] = regions['centroid-1']
        # H * W equals size**2 for the square images produced upstream.
        features[rows, 2] = regions['area'] / n_pixels
        for c in range(3):
            features[rows, 3 + c] = regions[f'intensity_mean-{c}'] / 255.
            features[rows, 6 + c] = regions[f'image_stdev-{c}'] / 255.

        seq_mask = np.zeros([self.num_seg])
        for lab, coord in zip(label, regions['coords']):
            inside = mask_np[coord[:, 0], coord[:, 1]]
            seq_mask[lab - 1] = 1 if np.sum(inside) / len(inside) >= 0.5 else 0

        neighbor_array = np.zeros([self.num_seg, self.num_seg])
        neighbor_array[bneighbors[0] - 1, bneighbors[1] - 1] = 1
        neighbor_array[bneighbors[1] - 1, bneighbors[0] - 1] = 1
        # Remove self-loops. Subtracting the identity (as done previously)
        # turned the diagonal of labels without a (k, k) pair -- unused labels
        # or single-pixel segments -- into -1.
        np.fill_diagonal(neighbor_array, 0)

        # Symmetric normalised Laplacian L = I - D^-1/2 A D^-1/2; degrees are
        # clipped to 1 so isolated nodes do not divide by zero.
        inv_sqrt_deg = neighbor_array.sum(axis=0).clip(1) ** -0.5
        laplacian = (np.eye(self.num_seg)
                     - inv_sqrt_deg[:, None] * neighbor_array * inv_sqrt_deg[None, :])
        # L is real symmetric: eigh yields real eigenvalues in ascending order
        # and orthonormal real eigenvectors.
        _, eig_vec = np.linalg.eigh(laplacian)
        # Drop the eigenvector of the smallest eigenvalue.
        pos_enc = torch.from_numpy(eig_vec[:, 1:]).float()

        return {
            'features': torch.tensor(features).float(),
            'seq_mask': torch.tensor(seq_mask).float(),
            'segments': torch.tensor(segments),
            'mask': self.tensor(mask),
            'img': self.tensor(img),
            'neighbor_array': torch.tensor(neighbor_array).float(),
            'pos_enc': pos_enc,
        }
 


class ToTensorSPFFT(object):
    """Superpixel node features plus Fourier descriptors of each superpixel contour.

    Args:
        num_seg: ``n_segments`` passed to SLIC and the node count of every
            per-node output.
        compactness: SLIC compactness.
        coeff: controls how many Fourier coefficients are kept per contour:
            ``2 * (coeff // 2) - 1`` complex values (DC term excluded), each
            contributing an amplitude and a phase feature.

    Raises:
        ValueError: if ``coeff`` is ``None`` or smaller than 2.
    """

    def __init__(self, num_seg, compactness, coeff):
        if coeff is None or coeff < 2:
            raise ValueError('coeff must be an integer >= 2, got {!r}'.format(coeff))
        self.tensor = transforms.ToTensor()
        self.num_seg = num_seg
        self.coeff = coeff
        self.compactness = compactness
        half = coeff // 2

        def fourier_descriptors(region):
            """Amplitudes followed by phases of the kept contour FFT coefficients."""
            region = (region * 255).astype(np.uint8)
            # NOTE(review): cv2 and resample_2d are not imported by name in this
            # file; presumably they arrive via the dataset.* star imports -- confirm.
            contour, hierarchy = cv2.findContours(region, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            # Only the first contour found is described.
            points = contour[0][:, 0, :]
            xi, yi = resample_2d(points, RESAMPLE_POINTS)
            contour_array = np.stack((xi, yi), axis=1)

            # Encode each (x, y) contour point as the complex number x + iy.
            contour_complex = np.empty(contour_array.shape[:-1], dtype=complex)
            contour_complex.real = contour_array[:, 0]
            contour_complex.imag = contour_array[:, 1]
            fourier_result = np.fft.fft(contour_complex)

            # Keep coefficients 1 .. half-1 and the last ``half`` ones. The
            # back slice must be ``-half``: ``-coeff//2`` parses as
            # ``(-coeff)//2`` and kept one extra coefficient for odd coeff,
            # overflowing the feature columns allocated in __call__.
            fourier_result = np.concatenate(
                (fourier_result[1:half], fourier_result[-half:]), axis=0)

            amp = abs(fourier_result)
            phase = np.arctan2(fourier_result.imag, fourier_result.real)
            return np.concatenate((amp, phase))

        # regionprops names the output columns after the function, so the
        # closure keeps the name ``fourier_descriptors``.
        self.fourier_descriptors = fourier_descriptors

    def __call__(self, sample):
        """Segment ``sample['image']`` and build per-superpixel tensors.

        Returns a dict with:
            features: ``(num_seg, 8 + n_desc)``: centroid row/col, mean RGB /
                255, ``image_stdev`` per channel / 255, then the ``n_desc``
                Fourier descriptor values.
            seq_mask: ``(num_seg,)``, 1 where the mean of ``mask / 255`` over
                the superpixel is >= 0.5.
            segments: the SLIC label map.
            mask, img: ``transforms.ToTensor()`` of the inputs.
            neighbor_array: ``(num_seg, num_seg)`` symmetric 0/1 adjacency over
                8-connected pixel pairs (self-pairs kept).
            edge_features: ``(num_seg, num_seg, 2)`` pairwise centroid
                row / column differences.
        """
        img, mask = sample['image'], sample['mask']
        img_np = np.array(img)
        mask_np = np.array(mask) / 255.
        segments = slic(img_np, n_segments=self.num_seg,
                        compactness=self.compactness,
                        max_num_iter=3,
                        convert2lab=True,
                        enforce_connectivity=False,
                        slic_zero=False)

        # Label pairs of horizontally, vertically and diagonally touching pixels.
        bneighbors = np.unique(np.hstack([
            np.vstack([segments[:, :-1].ravel(), segments[:, 1:].ravel()]),
            np.vstack([segments[:-1, :].ravel(), segments[1:, :].ravel()]),
            np.vstack([segments[:-1, :-1].ravel(), segments[1:, 1:].ravel()]),
            np.vstack([segments[1:, :-1].ravel(), segments[:-1, 1:].ravel()]),
        ]), axis=1)

        regions = regionprops_table(
            segments, intensity_image=img_np,
            properties=('label', 'centroid', 'intensity_mean', 'coords'),
            extra_properties=[image_stdev, self.fourier_descriptors])

        # Amplitude + phase for each of the 2 * (coeff // 2) - 1 coefficients.
        n_desc = ((self.coeff // 2) * 2 - 1) * 2
        label = regions['label']
        rows = label - 1
        features = np.zeros([self.num_seg, 8 + n_desc])
        features[rows, 0] = regions['centroid-0']
        features[rows, 1] = regions['centroid-1']
        for c in range(3):
            features[rows, 2 + c] = regions[f'intensity_mean-{c}'] / 255.
            features[rows, 5 + c] = regions[f'image_stdev-{c}'] / 255.
        for i in range(n_desc):
            features[rows, 8 + i] = regions[f'fourier_descriptors-{i}']

        seq_mask = np.zeros([self.num_seg])
        for lab, coord in zip(label, regions['coords']):
            inside = mask_np[coord[:, 0], coord[:, 1]]
            seq_mask[lab - 1] = 1 if np.sum(inside) / len(inside) >= 0.5 else 0

        neighbor_array = np.zeros([self.num_seg, self.num_seg])
        neighbor_array[bneighbors[0] - 1, bneighbors[1] - 1] = 1
        neighbor_array[bneighbors[1] - 1, bneighbors[0] - 1] = 1

        spatial_distances_x = features[:, 0:1] - features[:, 0:1].T
        spatial_distances_y = features[:, 1:2] - features[:, 1:2].T
        edge_features = np.stack((spatial_distances_x, spatial_distances_y), axis=2)

        return {
            'features': torch.tensor(features).float(),
            'seq_mask': torch.tensor(seq_mask).float(),
            'segments': torch.tensor(segments),
            'mask': self.tensor(mask),
            'img': self.tensor(img),
            'neighbor_array': torch.tensor(neighbor_array).float(),
            'edge_features': torch.from_numpy(edge_features).float(),
        }

class ToTensorSPContour(object):
    """Superpixel node features plus resampled contour coordinates (``contours_euc``).

    Args:
        num_seg: ``n_segments`` passed to SLIC and the node count of every
            per-node output.
        compactness: SLIC compactness. ``None`` (the default) uses the
            module-level ``COMPACTNESS`` constant, the only value used before
            this argument existed. ``SPDataset`` passes its compactness as the
            second positional argument, which previously raised TypeError.
    """

    def __init__(self, num_seg, compactness=None):
        self.tensor = transforms.ToTensor()
        self.num_seg = num_seg
        self.compactness = compactness

    def __call__(self, sample):
        """Segment ``sample['image']`` and build per-superpixel tensors.

        Returns a dict with:
            features: ``(num_seg, 8 + 2 * RESAMPLE_POINTS)``: centroid row/col,
                mean RGB / 255, ``image_stdev`` per channel / 255, then the
                ``contours_euc`` values (all ``-0`` columns, then all ``-1``).
            seq_mask: ``(num_seg,)``, 1 where the mean of ``mask / 255`` over
                the superpixel is >= 0.5.
            segments: the SLIC label map.
            mask, img: ``transforms.ToTensor()`` of the inputs.
            neighbor_array: ``(num_seg, num_seg)`` symmetric 0/1 adjacency over
                8-connected pixel pairs (self-pairs kept).
        """
        compactness = COMPACTNESS if self.compactness is None else self.compactness
        img, mask = sample['image'], sample['mask']
        img_np = np.array(img)
        mask_np = np.array(mask) / 255.
        segments = slic(img_np, n_segments=self.num_seg,
                        compactness=compactness,
                        max_num_iter=10,
                        convert2lab=True,
                        enforce_connectivity=True,
                        slic_zero=False)

        # Label pairs of horizontally, vertically and diagonally touching pixels.
        bneighbors = np.unique(np.hstack([
            np.vstack([segments[:, :-1].ravel(), segments[:, 1:].ravel()]),
            np.vstack([segments[:-1, :].ravel(), segments[1:, :].ravel()]),
            np.vstack([segments[:-1, :-1].ravel(), segments[1:, 1:].ravel()]),
            np.vstack([segments[1:, :-1].ravel(), segments[:-1, 1:].ravel()]),
        ]), axis=1)

        regions = regionprops_table(
            segments, intensity_image=img_np,
            properties=('label', 'centroid', 'area', 'intensity_mean', 'coords'),
            extra_properties=[image_stdev, contours_euc])

        label = regions['label']
        rows = label - 1
        features = np.zeros([self.num_seg, 8 + RESAMPLE_POINTS * 2])
        features[rows, 0] = regions['centroid-0']
        features[rows, 1] = regions['centroid-1']
        for c in range(3):
            features[rows, 2 + c] = regions[f'intensity_mean-{c}'] / 255.
            features[rows, 5 + c] = regions[f'image_stdev-{c}'] / 255.
        for i in range(RESAMPLE_POINTS):
            features[rows, 8 + i] = regions[f'contours_euc-{i}-0']
            features[rows, 8 + RESAMPLE_POINTS + i] = regions[f'contours_euc-{i}-1']

        seq_mask = np.zeros([self.num_seg])
        for lab, coord in zip(label, regions['coords']):
            inside = mask_np[coord[:, 0], coord[:, 1]]
            seq_mask[lab - 1] = 1 if np.sum(inside) / len(inside) >= 0.5 else 0

        neighbor_array = np.zeros([self.num_seg, self.num_seg])
        neighbor_array[bneighbors[0] - 1, bneighbors[1] - 1] = 1
        neighbor_array[bneighbors[1] - 1, bneighbors[0] - 1] = 1

        return {
            'features': torch.tensor(features).float(),
            'seq_mask': torch.tensor(seq_mask).float(),
            'segments': torch.tensor(segments),
            'mask': self.tensor(mask),
            'img': self.tensor(img),
            'neighbor_array': torch.tensor(neighbor_array).float(),
        }


class ToTensorSPCNN(object):
    """Superpixel colour (RGB / Lab / HSV) and LBP features with soft mask targets.

    Args:
        num_seg: ``n_segments`` passed to SLIC and the node count of every
            per-node output.
        compactness: SLIC compactness. ``None`` (the default) keeps the
            previously hard-coded value of 10. ``SPDataset`` passes its
            compactness as the second positional argument, which previously
            raised TypeError.
    """

    def __init__(self, num_seg, compactness=None):
        self.tensor = transforms.ToTensor()
        self.num_seg = num_seg
        self.compactness = 10 if compactness is None else compactness

    def __call__(self, sample):
        """Segment ``sample['image']`` and build per-superpixel tensors.

        Returns a dict with:
            features: ``(num_seg, 70)``: centroid row/col / 300, mean RGB,
                mean Lab, mean HSV, then the 59 ``lbp`` outputs.
            seq_mask: ``(num_seg,)`` mean of ``mask / 255`` over each
                superpixel (not thresholded).
            segments: the SLIC label map.
            mask, img: ``transforms.ToTensor()`` of the inputs.
        """
        img, mask = sample['image'], sample['mask']

        img_lab = np.array(color.rgb2lab(img))
        img_hsv = np.array(img.convert('HSV'))
        img_gray = np.array(img.convert('L'))
        img_np = np.array(img.convert('RGB'))
        mask_np = np.array(mask) / 255.
        segments = slic(img_np, n_segments=self.num_seg,
                        compactness=self.compactness,
                        max_num_iter=10,
                        convert2lab=True,
                        enforce_connectivity=False,
                        slic_zero=True, min_size_factor=0.)

        # NOTE(review): P=57 sample points at radius R=8 is unusual for LBP;
        # the ``lbp`` extra property is expected to return 59 values -- confirm.
        lbp_np = local_binary_pattern(img_gray, 57, 8)
        regions = regionprops_table(segments, intensity_image=img_np,
                                    properties=('label', 'centroid', 'intensity_mean', 'coords'))
        # All tables come from the same label map, so their rows line up.
        regions_lbp = regionprops_table(segments, intensity_image=lbp_np, extra_properties=[lbp])
        regions_lab = regionprops_table(segments, intensity_image=img_lab, properties=('label', 'intensity_mean'))
        regions_hsv = regionprops_table(segments, intensity_image=img_hsv, properties=('label', 'intensity_mean'))

        label = regions['label']
        rows = label - 1
        features = np.zeros([self.num_seg, 70])
        # NOTE(review): centroids are scaled by a fixed 300, not the image size.
        features[rows, 0] = regions['centroid-0'] / 300.
        features[rows, 1] = regions['centroid-1'] / 300.
        for c in range(3):
            features[rows, 2 + c] = regions[f'intensity_mean-{c}']
            features[rows, 5 + c] = regions_lab[f'intensity_mean-{c}']
            features[rows, 8 + c] = regions_hsv[f'intensity_mean-{c}']
        for ind in range(59):
            features[rows, ind + 11] = regions_lbp[f'lbp-{ind}']

        seq_mask = np.zeros([self.num_seg])
        for lab, coord in zip(label, regions['coords']):
            inside = mask_np[coord[:, 0], coord[:, 1]]
            seq_mask[lab - 1] = np.sum(inside) / len(inside)

        return {
            'features': torch.tensor(features).float(),
            'seq_mask': torch.tensor(seq_mask).float(),
            'segments': torch.tensor(segments),
            'mask': self.tensor(mask),
            'img': self.tensor(img),
        }

class ToTensorRaw(object):
    """Convert a sample's image and mask with ``transforms.ToTensor()``.

    Only the ``'image'`` and ``'mask'`` entries are kept in the result.
    """

    def __init__(self):
        self.tensor = transforms.ToTensor()

    def __call__(self, sample):
        convert = self.tensor
        return {key: convert(sample[key]) for key in ('image', 'mask')}

class SPDataset(data.Dataset):
    """Image/mask dataset converted to superpixel tensors by a selectable transform.

    Reads ``<root_dir>/Image`` and ``<root_dir>/Mask``; the two directory
    listings are sorted independently and paired by position.

    Args:
        root_dir: dataset root containing ``Image`` and ``Mask``.
        num_seg: number of SLIC segments for the superpixel transform.
        size: output side length of the square images.
        compactness: SLIC compactness forwarded to the transform.
        data_augmentation: random flip + crop if True, plain resize otherwise.
        dataloader: one of ``'SP'``, ``'SPFFT'``, ``'SPLAP'``, ``'SPCNN'``,
            ``'SPContour'``.
        coeff: Fourier coefficient count, used only by ``'SPFFT'``.

    Raises:
        ValueError: if ``dataloader`` is not one of the names above.
    """

    def __init__(self, root_dir, num_seg, size, compactness, data_augmentation=True, dataloader=None, coeff=None):
        self.root_dir = root_dir
        self.data_augmentation = data_augmentation
        self.image_list = sorted(os.listdir('{}/Image'.format(root_dir)))
        self.mask_list = sorted(os.listdir('{}/Mask'.format(root_dir)))

        if dataloader == 'SP':
            totensor = ToTensorSP(num_seg, compactness)
        elif dataloader == 'SPFFT':
            totensor = ToTensorSPFFT(num_seg, compactness, coeff)
        elif dataloader == 'SPLAP':
            totensor = ToTensorSPLAP(num_seg, compactness)
        elif dataloader == 'SPCNN':
            totensor = ToTensorSPCNN(num_seg, compactness)
        elif dataloader == 'SPContour':
            totensor = ToTensorSPContour(num_seg, compactness)
        else:
            # Raising a plain string is itself a TypeError in Python 3.
            raise ValueError('Unrecognized dataloader: {!r}'.format(dataloader))

        if data_augmentation:
            self.transform = transforms.Compose(
                [RandomFlip(0.5),
                 RandomCrop(size, int(size * 1.14)),
                 totensor])
        else:
            self.transform = transforms.Compose([Resize(size), totensor])

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, item):
        """Load pair ``item``, transform it and add its image file name as ``'file_name'``."""
        img_name = '{}/Image/{}'.format(self.root_dir, self.image_list[item])
        mask_name = '{}/Mask/{}'.format(self.root_dir, self.mask_list[item])
        # Close the files deterministically instead of relying on GC.
        with Image.open(img_name) as img_file:
            img = img_file.convert('RGB')
        with Image.open(mask_name) as mask_file:
            mask = mask_file.convert('L')

        sample = self.transform({'image': img, 'mask': mask})
        sample['file_name'] = self.image_list[item]
        return sample


class DUTSDataset(data.Dataset):
    """Raw image/mask pairs from ``<root_dir>/Image`` and ``<root_dir>/Mask``.

    Both listings are sorted and paired by position. Samples are randomly
    flipped and cropped only when ``train`` and ``data_augmentation`` are both
    truthy; otherwise they are resized to ``size`` x ``size``.
    """

    def __init__(self, root_dir, size, train=True, data_augmentation=True):
        self.root_dir = root_dir
        self.image_list = sorted(os.listdir('{}/Image'.format(root_dir)))
        self.mask_list = sorted(os.listdir('{}/Mask'.format(root_dir)))

        augment = train and data_augmentation
        if augment:
            steps = [RandomFlip(0.5), RandomCrop(size, int(size * 1.2))]
        else:
            steps = [Resize(size)]
        self.transform = transforms.Compose(steps + [ToTensorRaw()])

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, item):
        image_path = '{}/Image/{}'.format(self.root_dir, self.image_list[item])
        mask_path = '{}/Mask/{}'.format(self.root_dir, self.mask_list[item])
        pair = {
            'image': Image.open(image_path).convert('RGB'),
            'mask': Image.open(mask_path).convert('L'),
        }
        return self.transform(pair)

class SPDataModule(pl.LightningDataModule):
    """Lightning data module serving ``SPDataset`` train / val / test loaders.

    Recognised keyword arguments: ``dataset_tr``, ``dataset_val``,
    ``dataset_test``, ``batch_size``, ``num_workers`` (default 0), ``num_seg``
    (default 600), ``size``, ``dataloader``, ``coeff``, ``compactness``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        get = kwargs.get
        self.train_dir = get('dataset_tr')
        self.val_dir = get('dataset_val')
        self.test_dir = get('dataset_test')
        self.batch_size = get('batch_size')
        self.num_workers = get('num_workers', 0)
        self.num_seg = get('num_seg', 600)
        self.res = get('size')
        self.dataloader = get('dataloader')
        self.coeff = get('coeff')
        self.compactness = get('compactness')

    def _build_loader(self, directory, train):
        """Loader over ``directory``; augmentation and shuffling only when ``train``."""
        dataset = SPDataset(directory, self.num_seg, self.res, self.compactness,
                            train, self.dataloader, self.coeff)
        return DataLoader(dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=train, pin_memory=False)

    def train_dataloader(self):
        return self._build_loader(self.train_dir, True)

    def val_dataloader(self):
        return self._build_loader(self.val_dir, False)

    def test_dataloader(self):
        return self._build_loader(self.test_dir, False)




class DUTSDataModule(pl.LightningDataModule):
    """Lightning data module serving ``DUTSDataset`` train / val / test loaders.

    Recognised keyword arguments: ``dataset_tr``, ``dataset_val``,
    ``dataset_test``, ``batch_size``, ``num_workers`` (default 0), ``size``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        get = kwargs.get
        self.train_dir = get('dataset_tr')
        self.val_dir = get('dataset_val')
        self.test_dir = get('dataset_test')
        self.batch_size = get('batch_size')
        self.num_workers = get('num_workers', 0)
        self.image_size = get('size')

    def _build_loader(self, directory, train):
        """Loader over ``directory``; augmentation and shuffling only when ``train``."""
        dataset = DUTSDataset(directory, self.image_size, train, train)
        return DataLoader(dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=train, pin_memory=False)

    def train_dataloader(self):
        return self._build_loader(self.train_dir, True)

    def val_dataloader(self):
        return self._build_loader(self.val_dir, False)

    def test_dataloader(self):
        return self._build_loader(self.test_dir, False)